# -*- coding: utf-8 -*-
"""
-------------------------------------------------
   File Name：     jruntime_post.py
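   Description :  Post-processing for the jruntime diagnosis: converts uploaded JFR recordings into a flame graph and a markdown report.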
   Author :       liaozhaoyan
   date：          2023/7/31
-------------------------------------------------
   Change Activity:
                   2023/7/31:
-------------------------------------------------
"""
# Module author metadata.
__author__ = 'liaozhaoyan'


import json
import os
import re
import subprocess
import zipfile
from typing import List

from quickSvg import flame
from treelib import Tree

from .base import DiagnosisJobResult, DiagnosisPostProcessor, PostProcessResult


class CparseFold(object):
    """Parse a folded-stack file into a ``treelib.Tree`` of call frames.

    Each input line is ``frame1;frame2;...;frameN;COUNT``: ``;``-separated
    frames followed by a final ``;``-separated integer sample count
    (surrounding whitespace allowed).  The tree root is tagged ``"all"``;
    ``frame1`` becomes a direct child of the root.  Every node's ``data``
    dict holds ``symbol`` (the frame text) and ``samples`` (the summed count
    of all stacks passing through that node).
    """

    def __init__(self):
        super(CparseFold, self).__init__()

    def _addTree(self, tree, node, stack, nr):
        """Return the child of *node* tagged *stack*, crediting it *nr* samples.

        An existing child with the same tag is reused (its count grows);
        otherwise a new child node is created with *nr* samples.
        """
        for child in tree.children(node.identifier):
            if child.tag == stack:
                child.data['samples'] += nr
                return child
        return tree.create_node(tag=stack, parent=node,
                                data={'symbol': stack, 'samples': nr})

    def _addTrees(self, tree, root, stacks, nr):
        """Insert one stack below *root*, first frame nearest the root."""
        node = root
        for stack in stacks:
            node = self._addTree(tree, node, stack, nr)

    def toTree(self, fold):
        """Build and return the frame tree for the fold file at path *fold*.

        Blank lines are ignored; any other line whose last ``;`` field is
        not an integer raises ValueError.
        """
        tree = Tree()
        root = tree.create_node(tag="all", parent=tree.root,
                                data={"symbol": "all", "samples": 0})
        with open(fold, 'r') as f:
            for line in f:
                if not line.strip():
                    # Blank lines (e.g. a trailing empty line) carry no
                    # samples; int('') would otherwise raise ValueError.
                    continue
                *stacks, count = line.split(';')
                nr = int(count.strip())
                root.data['samples'] += nr
                self._addTrees(tree, root, stacks, nr)
        return tree


def getValue(data):
    """Return the sample count held in a flame-graph node's ``data`` dict.

    Used as the value callback passed to ``flame.Flame.render``.
    """
    samples = data['samples']
    return samples


def getNote(tree, node):
    """Build the hover note for *node* in the flame graph.

    Reports the node's sample count and its share of the root's and of its
    parent's samples.  The root node (no parent) is 100% of its parent.
    A zero denominator yields 0% instead of raising ZeroDivisionError,
    which can happen for stacks folded with a count of 0.
    """
    samples = node.data['samples']
    rootSamples = tree.get_node(tree.root).data['samples']
    perRoot = samples * 100.0 / rootSamples if rootSamples else 0.0

    parent = tree.parent(node.identifier)
    if parent is None:
        perParent = 100.0
    else:
        parentSamples = parent.data['samples']
        perParent = samples * 100.0 / parentSamples if parentSamples else 0.0
    return "catch %d samples, %f%% from root, %f%% from parent" % (samples, perRoot, perParent)


class CfoldFlame(object):
    """Render a folded-stack file as an SVG flame graph using quickSvg."""

    def __init__(self):
        super(CfoldFlame, self).__init__()
        self._parse = CparseFold()

    def toFlame(self, fold, save='out.svg'):
        """Parse *fold*, draw the flame graph to *save* and return the tree.

        When the tree holds no samples at all, *save* is still created but
        left empty instead of being rendered.
        """
        tree = self._parse.toTree(fold)
        total = tree.get_node(tree.root).data['samples']
        if total <= 0:
            # Nothing sampled: create/truncate the output file and stop.
            with open(save, 'w') as empty:
                empty.write("")
            return tree
        painter = flame.Flame(save)
        painter.render(tree, getValue, getNote, "java runtime flame.")
        return tree


def _parseJavaFold(java):
    """Split a Java fold path ``.../<prefix>_<pid>[_<cgroup>].fold`` into (pid, cgroup).

    Only the base name is parsed, so underscores in directory names do not
    corrupt the pid.  The cgroup defaults to ``"0"`` when absent.  A base
    name without any ``_`` raises IndexError.
    """
    fName, _ = os.path.splitext(os.path.basename(java))
    names = fName.split("_", 2)
    pid = names[1]
    cgroup = names[2] if len(names) > 2 else "0"
    return pid, cgroup


def combineFold(javas, orig="raptor.fold", oFile="combine.fold"):
    """Merge the system-wide fold file with per-process Java fold files.

    Lines of *orig* are ``<pid>-<rest>``.  For a pid that has a Java fold
    file, lines containing ``[unknown]`` are dropped (the original note says
    unknown frames are the Java stack, which the Java file supplies) and the
    rest are re-headed as ``<pid>-java-<cgroup>;<rest>``.  All other lines
    are copied unchanged.  Each Java fold line ``<stack> <count>`` becomes
    ``<pid>-java-<cgroup>;<stack>; <count>`` so the count is the last
    ``;``-separated field.  Blank lines in any input are skipped.

    :param javas: paths of the Java fold files
    :param orig: path of the system fold file; if missing, a message is
        printed and it is treated as empty
    :param oFile: path of the combined output file
    :return: *oFile*
    """
    javaInfo = [(java,) + _parseJavaFold(java) for java in javas]
    # Later entries win for a duplicated pid, as with sequential assignment.
    pids = {pid: cgroup for _, pid, cgroup in javaInfo}

    lines = []
    try:
        with open(orig, 'r') as f:
            for line in f:
                if not line.strip():
                    continue
                pid, sep, rest = line.partition("-")
                if sep and pid in pids:
                    if "[unknown]" in line:   # unknown is java stack
                        continue
                    head = "%s-java-%s" % (pid, pids[pid])
                    lines.append(head + ";" + rest.strip())
                else:
                    lines.append(line.strip())
    except FileNotFoundError:
        print("no raptor.fold.")

    for java, pid, cgroup in javaInfo:
        head = "%s-java-%s" % (pid, cgroup)
        try:
            with open(java, 'r') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    stacks, nr = line.rsplit(" ", 1)
                    lines.append(";".join([head, stacks, " %s" % nr]))
        except FileNotFoundError:
            continue

    with open(oFile, 'w') as f:
        f.write("\n".join(lines))
    return oFile


class CjAna(object):
    """Heuristic analyser that turns a flame-graph tree into a markdown report.

    The tree's root children are expected to be tagged ``<pid>-<comm>`` (the
    head format written by ``combineFold``).  Kernel threads and Java
    processes whose ``<comm>`` matches one of the patterns below are
    inspected for known hot spots; each kind of finding is reported at most
    once per :meth:`report` call.
    """

    def __init__(self):
        # comm regex -> handler returning a single message string or None
        self._kernel_proc = {
            r"kcompactd[0-9]+": self._compact,
            r"ksoftirqd/[0-9]+": self._softirq,
            r"kswapd[0-9]+": self._swap,
        }
        # comm regex -> handler returning a list of messages
        self._runtime_proc = {
            r"java-[0-9a-f]{8}": self._jrun,
        }
        self._re_jGC = re.compile(r"GC Thread#[0-9]+")
        # Initialise the de-duplication flags so the handlers are usable
        # even before report() has been called.
        self._clearFlags()

    def _clearFlags(self):
        """Reset the "already reported" flags for a new analysis."""
        self._flags = {
            "compact": False,
            "rcu_race": False,
            "net_stream": False,
            "swap": False
        }
        self._jFlags = {
            "gc": False,
        }

    def _getChildSyms(self, tree, node):
        """Return the tags of *node* and of all nodes below it."""
        return [tree.get_node(cell).tag for cell in tree.expand_tree(node.identifier)]

    def _j_checkGC(self, tag, res):
        """Append the GC advice to *res* the first time a GC thread tag is seen."""
        if self._jFlags['gc']:
            return
        if self._re_jGC.match(tag):
            # NOTE(review): unlike the other advice links this one has no
            # leading "/" - confirm a relative link is intended.
            res.append(u'java 进程在处理GC，建议[执行java栈信息诊断](diagnose/link/jruntime)')
            self._jFlags['gc'] = True

    def _jrun(self, tree, node):
        """Inspect the direct children (threads) of a Java process node."""
        res = []
        for child in tree.children(node.identifier):
            self._j_checkGC(child.tag, res)
        return res

    def _swap(self, tree, node):
        """Advise on kswapd activity, at most once per analysis."""
        if self._flags['swap']:
            return None
        # Mark as reported so further kswapd threads (kswapd1, ...) do not
        # repeat the same advice.
        self._flags['swap'] = True
        return u"系统在执行内存回收，建议执行[系统内存分布诊断](/diagnose/memory/memgraph)"

    def _compact(self, tree, node):
        """Advise on kcompactd activity, at most once per analysis."""
        if self._flags['compact']:
            return None
        self._flags['compact'] = True
        return u"系统内存伙伴子系统正在工作，建议执行[系统内存碎片化诊断](/diagnose/memory/memgraph)"

    def _softirq(self, tree, node):
        """Advise on RCU callbacks or NAPI polling seen under a ksoftirqd node.

        Returns None when neither symbol is found or the advice was already
        given.
        """
        syms = self._getChildSyms(tree, node)
        if 'rcu_do_batch' in syms:
            if self._flags['rcu_race']:
                return None
            self._flags['rcu_race'] = True
            return u"rcu 锁处理在内核线程化中处理，建议[执行系统负载诊断](/diagnose/cpu/loadtask)"
        if 'napi_poll' in syms:
            if self._flags['net_stream']:
                return None
            self._flags['net_stream'] = True
            return u"网络收包在内核线程化处理，确认下[当前网络流量](/monitor/node_monitor)"
        return None

    def _walk_process(self, tree):
        """Run the matching handlers on every root child.

        :return: ``(j_res, k_res)`` - Java findings and kernel findings
        """
        k_res = []
        j_res = []
        for child in tree.children(tree.root):
            # Tags are "<pid>-<comm>"; a tag without "-" yields an empty comm
            # (matching no pattern) instead of raising ValueError.
            _, _, comm = child.tag.partition("-")
            for pattern, handler in self._kernel_proc.items():
                if re.match(pattern, comm):
                    r = handler(tree, child)
                    if r:
                        k_res.append(r)
            for pattern, handler in self._runtime_proc.items():
                if re.match(pattern, comm):
                    r = handler(tree, child)
                    if r:
                        j_res.extend(r)
        return j_res, k_res

    def _maxChild(self, tree, node):
        """Return the child of *node* with the most samples (first on ties), or None."""
        children = tree.children(node.identifier)
        if not children:
            return None
        return max(children, key=lambda c: c.data['samples'])

    def _topTree(self, tree):
        """Return the tags along the hottest root-to-leaf path."""
        node = tree.get_node(tree.root)
        stacks = []
        while node:
            stacks.append(node.tag)
            node = self._maxChild(tree, node)
        return stacks

    def _report(self, stacks, j_res, k_res):
        """Assemble the markdown report; returns "" when there is nothing to report."""
        jFlag = 'java.lang.Thread.run()' in stacks
        mds = []
        if jFlag:
            # stacks[1] is expected to be the process head "<pid>-java-<cgroup>".
            pid, _, rest = stacks[1].partition("-")
            _, _, cgrp = rest.partition("-")

            rs = [u"# 热点信息在Java 应用中，请应用开发者关注热点栈信息，调用栈在最末\n",
                  u"* 进程pid: %s" % pid,
                  u"* cgroup id: %s" % cgrp
                  ]
            mds.append("\n".join(rs))

        if j_res:
            rs = [u"# java 进程异常分析：\n"]
            rs.extend("* " + r for r in j_res)
            mds.append("\n".join(rs))

        if k_res:
            rs = [u"# 系统异常分析：\n"]
            rs.extend("* " + r for r in k_res)
            mds.append("\n".join(rs))

        if jFlag:
            rs = [u"# java热点栈信息：\n", '\n```']
            rs.extend(stacks)
            rs.append('```')
            mds.append("\n".join(rs))
        return "\n\n".join(mds)

    def report(self, tree):
        """Analyse *tree* and return the markdown report."""
        self._clearFlags()
        stacks = self._topTree(tree)
        j_res, k_res = self._walk_process(tree)
        return self._report(stacks, j_res, k_res)


class CjfrZip(object):
    """Unpack a JFR zip archive and turn it into a flame graph plus report.

    All work happens in the constructor.  Afterwards the archive's directory
    contains one ``.fold`` file per ``.jfr`` recording plus ``combine.fold``,
    ``out.svg`` and ``res.md``.

    NOTE: this temporarily changes the process-wide working directory (to
    this module's directory, then to the archive's directory), so it must
    not run concurrently with other cwd-dependent code.
    """

    def __init__(self, path):
        super(CjfrZip, self).__init__()

        pwd = os.getcwd()
        try:
            # jfrparser.sh is invoked by a path relative to this module's
            # directory, so run from there.
            os.chdir(os.path.dirname(os.path.abspath(__file__)))

            dName = self._unzip(path)
            self._toFold(dName)
            self._combineFlame(dName)
        finally:
            os.chdir(pwd)

    def _unzip(self, path):
        """Extract every member of the zip at *path* into the zip's directory.

        :return: absolute path of the directory containing *path*
        """
        dName = os.path.abspath(os.path.dirname(path))
        with zipfile.ZipFile(path) as zf:
            # zipfile strips absolute prefixes and ".." components from
            # member names while extracting.
            zf.extractall(dName)
        return dName

    def _transJfr(self, dName, fName):
        """Convert ``dName/fName`` (a .jfr recording) into ``<name>.fold`` next to it.

        Must run from a directory containing ``jfrFlold/jfrparser.sh``.
        The script's stdout is discarded and its exit status is not checked.
        """
        name, _ = os.path.splitext(fName)
        iName = os.path.join(dName, fName)
        oName = os.path.join(dName, name + ".fold")
        # Argument list instead of a shell string: file names come from the
        # uploaded archive and must not be interpreted by a shell.  stdout
        # goes to DEVNULL rather than an unread, early-closed pipe, which
        # could kill the script with SIGPIPE.
        subprocess.run(
            ["bash", "./jfrFlold/jfrparser.sh", "-e", "cpu", "-i", iName, "-o", oName],
            stdout=subprocess.DEVNULL,
            check=False,
        )

    def _toFold(self, dName):
        """Convert every ``.jfr`` file in *dName* into a ``.fold`` file."""
        for fName in os.listdir(dName):
            if fName.endswith(".jfr"):
                self._transJfr(dName, fName)

    def _combineFlame(self, dName):
        """Merge the fold files in *dName*, render ``out.svg`` and write ``res.md`` there."""
        jList = ["%s.fold" % os.path.splitext(fName)[0]
                 for fName in os.listdir(dName) if fName.endswith(".jfr")]
        pwd = os.getcwd()
        os.chdir(dName)
        try:
            fold = combineFold(jList)
            tree = CfoldFlame().toFlame(fold)
            report = CjAna().report(tree)
            if report == "":
                report = u"# java 运行时无明显问题."
            with open("res.md", 'w') as fRes:
                fRes.write(report)
        finally:
            os.chdir(pwd)


class PostProcessor(DiagnosisPostProcessor):
    """Post-processor for the jruntime diagnosis.

    Converts the uploaded JFR archive (the first file of the first job
    result) into a flame graph, a markdown report and the combined fold
    text, and returns them under the ``svgdata``, ``jruntime_data`` and
    ``fold`` keys.
    """

    def parse_diagnosis_result(self, results: List[DiagnosisJobResult]) -> PostProcessResult:
        outcome = PostProcessResult(
            code=0,
            err_msg="",
            result={}
        )

        archive = results[0].file_list[0].local_path
        outDir = os.path.dirname(archive)
        CjfrZip(archive)

        def _read(name):
            # All artefacts are written by CjfrZip next to the archive.
            with open(os.path.join(outDir, name)) as fh:
                return fh.read()

        svg = _read("out.svg")
        md = _read("res.md")
        fold = _read("combine.fold")

        # Presumably the front-end panel configuration consuming these
        # datasources (kept from an earlier note; confirm against the UI):
        #   key "jruntime_result", type "markdown", datasource "jruntime_data"
        #   key "jruntime_set", type "svg", title "runtime diagnosis result",
        #   datasource "svgdata"
        outcome.result = {
            "svgdata": {"data": [{"key": 0, "value": svg}]},
            "fold": {"data": fold},
            "jruntime_data": {"data": md},
        }
        return outcome
